import rdkit.Chem as Chem
from mol_graph import bond_fdim, bond_features
import numpy as np

# Bond classes: index 0 is the string sentinel "NOBOND", the remaining
# entries are RDKit bond-type enum members.
BOND_TYPE = ["NOBOND", Chem.rdchem.BondType.SINGLE, Chem.rdchem.BondType.DOUBLE,
             Chem.rdchem.BondType.TRIPLE, Chem.rdchem.BondType.AROMATIC]
# Number of entries in BOND_TYPE.
N_BOND_CLASS = len(BOND_TYPE)
# Length of one pairwise feature vector produced by get_bin_feature.
# NOTE(review): get_bin_feature writes slot 0, slots 1..bond_fdim and the
# last four slots; with a total of 4 + bond_fdim, slot `bond_fdim` is shared
# by the bond-feature slice and the first of the four trailing flags.
binary_fdim = 4 + bond_fdim
# Value get_bond_label stores for masked entries (self pairs and padding).
INVALID_BOND = -1


def get_bin_feature(r, max_natoms):
    '''
    Build pairwise atom-atom features for one reaction side.

    For every ordered atom pair the feature vector describes the bond between
    the two atoms (if any) and whether they belong to the same molecule.
    It is used in the global attention mechanism.

    Layout of one feature vector (length ``binary_fdim``):
        [0]               1.0 when the two atoms are not bonded
        [1:1 + bond_fdim] ``bond_features(bond)`` when they are bonded
        [-4]              1.0 when the atoms are in different molecules
        [-3]              1.0 when the atoms are in the same molecule
        [-2]              1.0 when ``r`` contains a single molecule
        [-1]              1.0 when ``r`` contains several molecules

    NOTE(review): ``binary_fdim`` is ``4 + bond_fdim``, so slot ``bond_fdim``
    is written by both the bond-feature slice and the ``[-4]`` flag; the flag
    is written last and wins. This is kept unchanged because altering it
    would change the feature layout seen by callers.

    Args:
        r: SMILES of the reaction side, molecules separated by '.'. Every
            atom must carry a 'molAtomMapNumber'; atoms are indexed by
            map number minus one.
        max_natoms: side length the output is padded (or truncated) to.

    Returns:
        Float array of shape ``(max_natoms, max_natoms, binary_fdim)``.
        Self pairs and pairs involving a padding index are all zeros.

    Raises:
        KeyError: if an atom has no map number, or if some index below the
            atom count is not covered by the map numbers.
    '''
    fragments = r.split('.')
    n_comp = len(fragments)

    # Record which '.'-separated molecule every mapped atom belongs to.
    comp = {}
    for frag_idx, frag in enumerate(fragments):
        for atom in Chem.MolFromSmiles(frag).GetAtoms():
            comp[atom.GetIntProp('molAtomMapNumber') - 1] = frag_idx

    rmol = Chem.MolFromSmiles(r)
    n_atoms = rmol.GetNumAtoms()
    # Only indices below both the atom count and the padded size get features.
    n_real = min(n_atoms, max_natoms)

    features = np.zeros((max_natoms, max_natoms, binary_fdim))
    if n_real < 2:
        # No off-diagonal real pair exists, so everything stays zero.
        return features

    # View on the real-atom corner; writes go straight into `features`.
    block = features[:n_real, :n_real]

    # Start from "not bonded" for every real pair, then overwrite bonded ones.
    block[:, :, 0] = 1.0
    for bond in rmol.GetBonds():
        a1 = bond.GetBeginAtom().GetIntProp('molAtomMapNumber') - 1
        a2 = bond.GetEndAtom().GetIntProp('molAtomMapNumber') - 1
        if not (0 <= a1 < n_real and 0 <= a2 < n_real):
            continue
        fbond = bond_features(bond)
        for p, q in ((a1, a2), (a2, a1)):
            block[p, q, 0] = 0.0
            block[p, q, 1:1 + bond_fdim] = fbond

    # Molecule-membership flags are written after the bond features, so they
    # take precedence on the shared slot (see the note in the docstring).
    comp_ids = np.array([comp[idx] for idx in range(n_real)])
    different = comp_ids[:, None] != comp_ids[None, :]
    block[:, :, -4] = different
    block[:, :, -3] = ~different
    block[:, :, -2] = 1.0 if n_comp == 1 else 0.0
    block[:, :, -1] = 1.0 if n_comp > 1 else 0.0

    # An atom paired with itself carries no features at all.
    diag = np.arange(n_real)
    block[diag, diag, :] = 0.0
    return features


# Bond order -> class index; the class index is the position of the bond
# order in this sequence (0.0 = no bond, 1.5 = aromatic order value).
bo_to_index = dict(zip((0.0, 1, 2, 3, 1.5), range(5)))
# Number of bond-order classes.
nbos = len(bo_to_index)


def get_bond_label(r, edits, max_natoms):
    """Build dense and sparse bond-edit labels for one reaction.

    Args:
        r: SMILES of the whole reaction side; only its atom count is used.
        edits: ';'-separated edits, each of the form "a1-a2-bo" where a1 and
            a2 are 1-based atom numbers and ``float(bo)`` is a key of
            ``bo_to_index``. Empty segments (an empty string, a leading or
            trailing ';') are skipped.
        max_natoms: side length the label tensor is padded to.

    Returns:
        A tuple ``(labels, sp_labels)``:

        * ``labels`` - float array of length ``max_natoms * max_natoms * nbos``.
          Entries are 1 for an edited (atom, atom, bond order) triple, 0 for
          an unedited one, and ``INVALID_BOND`` for self pairs and for pairs
          that involve an index at or beyond the atom count of ``r``.
        * ``sp_labels`` - ascending list of the flat positions in ``labels``
          that hold 1.

    Raises:
        ValueError: if a non-empty segment does not split into three parts
            or an atom number is not an integer.
        KeyError: if a bond order is not a key of ``bo_to_index``.
        IndexError: if an atom number exceeds ``max_natoms``.
    """
    n_atoms = Chem.MolFromSmiles(r).GetNumAtoms()

    # Symmetric one-hot tensor of the edited (atom, atom, bond order) triples.
    rmap = np.zeros((max_natoms, max_natoms, nbos))
    for edit in edits.split(';'):
        if not edit:
            continue
        a1, a2, bo = edit.split('-')
        x = int(a1) - 1
        y = int(a2) - 1
        z = bo_to_index[float(bo)]
        rmap[x, y, z] = rmap[y, x, z] = 1

    # A pair is valid when both indices are real atoms and they differ.
    idx = np.arange(max_natoms)
    is_real = idx < n_atoms
    valid = is_real[:, None] & is_real[None, :] & (idx[:, None] != idx[None, :])

    # Flat position of (i, j, k) is (i * max_natoms + j) * nbos + k, i.e.
    # NumPy's default C-order (row-major) flattening.
    # TODO(review): confirm the consumer of sp_labels flattens the same way.
    labels = np.where(valid[:, :, None], rmap, INVALID_BOND).reshape(-1)
    sp_labels = np.flatnonzero(labels == 1).tolist()
    return labels, sp_labels


def get_all_batch(re_list):
    """Build padded features and labels for a batch of reactions.

    Args:
        re_list: iterable of ``(r, e)`` pairs, where ``r`` is the SMILES of a
            reaction side and ``e`` its edit string as accepted by
            ``get_bond_label``.

    Returns:
        A tuple ``(features, labels, sp_labels)``: the per-reaction outputs of
        ``get_bin_feature`` and the dense labels of ``get_bond_label`` packed
        with ``np.array``, and a plain list holding each reaction's sparse
        label list. Every reaction is padded to the largest atom count found
        in the batch.
    """
    # First pass: keep the pairs and find the largest atom count in the batch.
    pairs = []
    widest = 0
    for smiles, edit_str in re_list:
        n_atoms = Chem.MolFromSmiles(smiles).GetNumAtoms()
        pairs.append((smiles, edit_str))
        widest = max(widest, n_atoms)

    # Second pass: build every sample at the common padded size.
    dense_labels, pair_features, sparse_labels = [], [], []
    for smiles, edit_str in pairs:
        dense, sparse = get_bond_label(smiles, edit_str, widest)
        pair_features.append(get_bin_feature(smiles, widest))
        dense_labels.append(dense)
        sparse_labels.append(sparse)

    return np.array(pair_features), np.array(dense_labels), sparse_labels


# bmask
# get bond_mask used for masking scores
def get_bond_mask(r, max_natoms):
    """Build the mask marking atom pairs that must not be scored.

    Args:
        r: SMILES of the whole reaction side; only its atom count is used.
        max_natoms: side length the mask is padded to.

    Returns:
        Float array of shape ``(max_natoms, max_natoms, len(bo_to_index))``
        holding 10000 for self pairs and for pairs involving an index at or
        beyond the atom count of ``r``, and 0 everywhere else.
    """
    n_atoms = Chem.MolFromSmiles(r).GetNumAtoms()

    # A pair is masked when either index is padding or both indices coincide.
    idx = np.arange(max_natoms)
    is_real = idx < n_atoms
    masked = ~(is_real[:, None] & is_real[None, :]) | (idx[:, None] == idx[None, :])

    bond_mask = np.zeros((max_natoms, max_natoms, len(bo_to_index)))
    # The 2-D boolean index selects whole bond-order rows, so every bond
    # order of a masked pair receives the value.
    bond_mask[masked] = 10000
    return bond_mask


def get_feature_batch(r_list):
    """Build padded pairwise features and bond masks for a batch.

    Args:
        r_list: SMILES strings, one per reaction side. It is iterated twice,
            so a one-shot iterator yields an empty batch.

    Returns:
        A tuple ``(binary, bond_mask)``: the per-reaction outputs of
        ``get_bin_feature`` and ``get_bond_mask`` packed with ``np.array``,
        every reaction padded to the largest atom count in the batch.
    """
    # Largest atom count in the batch (0 for an empty batch).
    atom_counts = [Chem.MolFromSmiles(smiles).GetNumAtoms() for smiles in r_list]
    widest = max([0] + atom_counts)

    # Build both tensors per reaction, features first and mask second.
    per_reaction = [
        (get_bin_feature(smiles, widest), get_bond_mask(smiles, widest))
        for smiles in r_list
    ]
    binary = [pair[0] for pair in per_reaction]
    bond_mask = [pair[1] for pair in per_reaction]

    return np.array(binary), np.array(bond_mask)
